import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import enformer
from enformer_pytorch.data import str_to_one_hot


class SeqDataset(Dataset):
    """Map-style dataset of DNA sequence strings, optionally paired with labels.

    Args:
        seqs: Either a ``pd.DataFrame`` whose first column holds sequence
            strings and whose remaining columns hold numeric targets (the
            dataset is then "labeled"), or an iterable of sequence strings
            (unlabeled). A single string is treated as a one-item dataset.
        rc: If True, labeled items are reverse-complemented with probability
            0.5 as data augmentation. Unlabeled items are never flipped.

    Items are one-hot encoded with ``str_to_one_hot`` and returned as
    ``(encoding, label)`` when labeled, or just ``encoding`` otherwise.
    """

    # Complement table for reverse-complement augmentation (A<->T, C<->G,
    # N->N, case preserved). Characters not listed are left unchanged.
    _RC_TABLE = str.maketrans("ACGTNacgtn", "TGCANtgcan")

    def __init__(self, seqs, rc=False):
        super().__init__()
        self.is_labeled = False
        self.rng = np.random.RandomState()
        self.rc = rc
        if isinstance(seqs, pd.DataFrame):
            # Reset the index so self.seqs[idx] is positional. A Series keeps
            # the frame's index, and a non-default index (e.g. after a split
            # or shuffle) would otherwise make label lookup fail or misalign
            # with self.labels.
            self.seqs = seqs.iloc[:, 0].reset_index(drop=True)
            self.labels = torch.Tensor(seqs.values[:, 1:].astype(np.float32))  # N, n_tasks
            self.is_labeled = True
        elif isinstance(seqs, str):
            # Wrap a lone sequence rather than iterating over its characters.
            self.seqs = [seqs]
        else:
            # Unlabeled input: any iterable of sequence strings.
            self.seqs = list(seqs)

    @classmethod
    def _reverse_complement(cls, seq):
        """Return the reverse complement of the sequence string ``seq``."""
        return seq[::-1].translate(cls._RC_TABLE)

    def __len__(self):
        return len(self.seqs)

    def __getitem__(self, idx):
        seq = self.seqs[idx]
        if self.is_labeled:
            label = self.labels[idx]
            # NOTE(review): with num_workers > 1 each DataLoader worker
            # presumably gets a copy of this RandomState with identical state,
            # so flip decisions may be correlated across workers — confirm.
            if self.rc and self.rng.randint(2):
                seq = self._reverse_complement(seq)

        # The original swapaxes(1, 2) + squeeze() implies a leading size-1
        # axis (presumably (1, L, 4) for a single string — TODO confirm).
        # Dropping only that axis keeps a length-1 sequence axis intact,
        # which a bare squeeze() would collapse.
        seq = str_to_one_hot(seq).swapaxes(1, 2)[0]
        if self.is_labeled:
            return seq, label
        return seq


class LightningModel(pl.LightningModule):
    """PyTorch Lightning wrapper around a model class from the ``enformer`` module.

    Args:
        model_type: Name of the model class looked up on the ``enformer``
            module; it is instantiated with ``**kwargs``.
        lr: Adam learning rate.
        loss: ``'poisson'`` selects ``nn.PoissonNLLLoss(log_input=False,
            full=True)`` and makes ``forward`` exponentiate the raw model
            output; any other value selects ``nn.MSELoss``.
        **kwargs: Forwarded to the model class constructor.
    """

    def __init__(self, model_type="Enformer", lr=1e-4, loss='poisson', **kwargs):
        super().__init__()

        self.save_hyperparameters(ignore=["model"])
        self.lr = lr
        model_cls = getattr(enformer, model_type)
        self.model = model_cls(**kwargs)
        # Must reflect the configured loss: forward() keys its exp() on it.
        # (Previously the literal 'loss', which made that branch unreachable.)
        self.loss_type = loss
        if loss == 'poisson':
            # log_input=False: the loss expects rates, which forward()
            # produces by exponentiating the raw model output.
            self.loss = nn.PoissonNLLLoss(log_input=False, full=True)
        else:
            self.loss = nn.MSELoss()
        # Per-step validation losses for the current validation epoch.
        self._val_step_losses = []

    def forward(self, x):
        """Run the wrapped model on a batch.

        ``x`` may be a tensor, an ``(input, target)`` tuple/list as yielded by a
        labeled DataLoader (only the input is used), or a list of sequence
        strings, which is one-hot encoded here.
        """
        if isinstance(x, tuple):
            x = x[0]
        if isinstance(x, list):
            if isinstance(x[0], str):
                # Encode, then move to the model's device so raw string input
                # also works when the model lives on a GPU.
                x = str_to_one_hot(x).swapaxes(1, 2).to(self.device)
            else:
                # A list (input, target) batch: use only the input.
                x = x[0]

        out = self.model(x)
        if self.loss_type == 'poisson':
            out = torch.exp(out)
        return out

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        loss = self.loss(y_hat, y)
        self.log("train_loss", loss, logger=False, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def on_validation_epoch_start(self):
        # Start a fresh accumulator for every validation run (including the
        # sanity check).
        self._val_step_losses = []

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        loss = self.loss(y_hat, y)
        self.log("val_loss", loss, logger=False, on_step=False, on_epoch=True)
        self._val_step_losses.append(loss.detach())
        return loss

    def on_validation_epoch_end(self):
        # Replaces validation_epoch_end, which Lightning 2.0 removed; this
        # hook exists in both 1.x and 2.x. Prints the unweighted mean of the
        # per-step losses, as before.
        if self._val_step_losses:
            print("val_loss", torch.stack(self._val_step_losses).mean().item())

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.lr)

    def train_on_dataset(self, train_data, val_data, device=0, batch_size=512, num_workers=1, max_epochs=10):
        """Fit the model on ``train_data``, validating on ``val_data``.

        Checkpoints are kept by minimum ``val_loss``. Returns the fitted
        ``pl.Trainer``.
        """
        torch.set_float32_matmul_precision("medium")
        trainer = pl.Trainer(
            max_epochs=max_epochs,
            accelerator="gpu",
            devices=[device],
            logger=None,
            callbacks=[ModelCheckpoint(monitor="val_loss", mode="min")],
        )
        train_loader = DataLoader(
            train_data,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
        )
        val_loader = DataLoader(
            val_data,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
        )
        trainer.fit(model=self, train_dataloaders=train_loader, val_dataloaders=val_loader)
        return trainer

    def predict_on_dataset(
        self,
        dataset,
        device=0,
        num_workers=1,
        batch_size=512,
    ):
        """Predict on ``dataset`` in order and return a squeezed NumPy array.

        The per-batch predictions are concatenated along the batch axis.
        """
        torch.set_float32_matmul_precision("medium")
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
        trainer = pl.Trainer(accelerator="gpu", devices=[device], logger=None)
        return torch.concat(trainer.predict(self, dataloader)).cpu().detach().numpy().squeeze()